Paper Name: DiffDis: Empowering Generative Diffusion Model with Cross-Modal Discrimination Capability
Project Members: Furkan Genç, Barış Sarper Tezcan
In [ ]:
# define the constants
WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8
BATCH_SIZE = 1
root_dir = "../dataset/cc3m/train"
# training parameters
num_train_epochs = 6
Lambda = 1.0
save_steps = 5000
# optimizer parameters
learning_rate = 1e-5
discriminative_learning_rate = 1e-4 # New learning rate for discriminative tasks
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_weight_decay = 1e-4
adam_epsilon = 1e-8
# IMAGE TO TEXT
test_dataset = "CIFAR10" # Set to "CIFAR100" to use CIFAR100 dataset
# output directory
train_output_dir = "../results/output_1"
test_output_dir = "../results/" + test_dataset
inference_output_dir = "../results/text_to_image/output_1/last"
# Load the models
model_file = "data/v1-5-pruned.ckpt"
train_unet_file = None # Set to None to finetune from scratch, if specified, the diffusion model will be loaded from this file
test_unet_file = "../results/output_1/last.pt"
inference_unet_file = "../results/output_1/last.pt"
# EMA parameters
use_ema = False # Set to True to use EMA
ema_decay = 0.9999
warmup_steps = 1000
# TEXT TO IMAGE
prompt1 = "A river with boats docked and houses in the background"
prompt2 = "A piece of chocolate swirled cake on a plate"
prompt3 = "A large bed sitting next to a small Christmas Tree surrounded by pictures"
prompt4 = "A bear searching for food near the river"
prompts = [prompt1, prompt2, prompt3, prompt4]
uncond_prompt = "" # Also known as negative prompt
do_cfg = True
cfg_scale = 3 # min: 1, max: 14
num_samples = 1
# SAMPLER
sampler = "ddpm"
num_inference_steps = 50
seed = 42
In [ ]:
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm
from ddpm import DDPMSampler
from pipeline import get_time_embedding
from dataloader import train_dataloader
import model_loader
import time
from diffusion import TransformerBlock, UNet_Transformer # Ensure these are correctly imported
import pipeline
from PIL import Image
from pathlib import Path
from transformers import CLIPTokenizer
# Set the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load the models
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
ddpm = DDPMSampler(generator=None)
if train_unet_file is not None:
# Load the UNet model
print(f"Loading UNet model from {train_unet_file}")
models['diffusion'].load_state_dict(torch.load(train_unet_file)['model_state_dict'])
if 'best_loss' in torch.load(train_unet_file):
best_loss = torch.load(train_unet_file)['best_loss']
best_step = torch.load(train_unet_file)['best_step']
last_loss = torch.load(train_unet_file)['last_loss']
last_step = torch.load(train_unet_file)['last_step']
else:
best_loss = float('inf')
best_step = 0
last_loss = 0.0
last_step = 0
else:
best_loss = float('inf')
best_step = 0
last_loss = 0.0
last_step = 0
# TEXT TO IMAGE
tokenizer = CLIPTokenizer("./data/vocab.json", merges_file="./data/merges.txt")
# Disable gradient computations for the models['encoder'], DDPM, and models['clip'] models
for param in models['encoder'].parameters():
param.requires_grad = False
for param in models['clip'].parameters():
param.requires_grad = False
# Set the models['encoder'] and models['clip'] to eval mode
models['encoder'].eval()
models['clip'].eval()
# Separate parameters for discriminative tasks
discriminative_params = []
non_discriminative_params = []
for name, param in models['diffusion'].named_parameters():
if isinstance(getattr(models['diffusion'], name.split('.')[0], None), (TransformerBlock, UNet_Transformer)):
discriminative_params.append(param)
else:
non_discriminative_params.append(param)
# AdamW optimizer with separate learning rates
optimizer = torch.optim.AdamW([
{'params': non_discriminative_params, 'lr': learning_rate},
{'params': discriminative_params, 'lr': discriminative_learning_rate}
], betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay, eps=adam_epsilon)
if train_unet_file is not None:
print(f"Loading optimizer state from {train_unet_file}")
optimizer.load_state_dict(torch.load(train_unet_file)['optimizer_state_dict'])
# Linear warmup scheduler for non-discriminative parameters
def warmup_lr_lambda(current_step: int):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[
warmup_lr_lambda, # Apply warmup for non-discriminative params
lambda step: 1.0 # Keep constant learning rate for discriminative params
])
# EMA setup
if use_ema:
ema_unet = torch.optim.swa_utils.AveragedModel(models['diffusion'], avg_fn=lambda averaged_model_parameter, model_parameter, num_averaged: ema_decay * averaged_model_parameter + (1 - ema_decay) * model_parameter)
/home/furkan/miniconda3/envs/DiffDis/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
Loading samples: 5046/5046 Loaded 5046 samples.
In [ ]:
def train(num_train_epochs, device="cuda", save_steps=1000):
global best_loss, best_step, last_loss, last_step
if train_unet_file is not None:
first_epoch = last_step // len(train_dataloader)
global_step = last_step + 1
else:
first_epoch = 0
global_step = 0
accumulator = 0
# Move models to the device
models['encoder'].to(device)
models['clip'].to(device)
models['diffusion'].to(device)
if use_ema:
ema_unet.to(device)
num_train_epochs = tqdm(range(first_epoch, num_train_epochs), desc="Epoch")
for epoch in num_train_epochs:
train_loss = 0.0
num_train_steps = len(train_dataloader)
for step, batch in enumerate(train_dataloader):
start_time = time.time()
# Extract images and texts from batch
images = batch["pixel_values"]
texts = batch["input_ids"]
# Move batch to the device
images = images.to(device)
texts = texts.to(device)
# Encode images to latent space
encoder_noise = torch.randn(images.shape[0], 4, LATENTS_HEIGHT, LATENTS_WIDTH).to(device) # Shape (BATCH_SIZE, 4, 32, 32)
latents = models['encoder'](images, encoder_noise)
# Sample noise and timesteps for diffusion process
bsz = latents.shape[0]
timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()
text_timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()
# Add noise to latents and texts
noisy_latents, image_noise = ddpm.add_noise(latents, timesteps)
encoder_hidden_states = models['clip'](texts)
noisy_text_query, text_noise = ddpm.add_noise(encoder_hidden_states, text_timesteps)
# Get time embeddings
image_time_embeddings = get_time_embedding(timesteps, is_image=True).to(device)
text_time_embeddings = get_time_embedding(timesteps, is_image=False).to(device)
# Average and normalize text time embeddings
average_noisy_text_query = noisy_text_query.mean(dim=1)
text_query = F.normalize(average_noisy_text_query, p=2, dim=-1)
# Randomly drop 10% of text and image conditions: Context Free Guidance
if torch.rand(1).item() < 0.1:
text_query = torch.zeros_like(text_query)
if torch.rand(1).item() < 0.1:
noisy_latents = torch.zeros_like(noisy_latents)
# Predict the noise residual and compute loss
image_pred, text_pred = models['diffusion'](noisy_latents, encoder_hidden_states, image_time_embeddings, text_time_embeddings, text_query)
image_loss = F.mse_loss(image_pred.float(), image_noise.float(), reduction="mean")
text_loss = F.mse_loss(text_pred.float(), text_query.float(), reduction="mean")
loss = image_loss + Lambda * text_loss
train_loss += loss.item()
accumulator += loss.item()
# Backpropagate
loss.backward()
optimizer.step()
optimizer.zero_grad()
scheduler.step()
if use_ema:
ema_unet.update_parameters(models['diffusion'])
end_time = time.time()
if train_unet_file is not None and epoch == first_epoch:
print(f"Step: {step+1+last_step}/{num_train_steps+last_step} Loss: {loss.item()} Time: {end_time - start_time}", end="\r")
else:
print(f"Step: {step}/{num_train_steps} Loss: {loss.item()} Time: {end_time - start_time}", end="\r")
if global_step % save_steps == 0 and global_step > 0:
# Check if the current step's loss is the best
if accumulator / save_steps < best_loss:
best_loss = accumulator / save_steps
best_step = global_step
best_save_path = os.path.join(train_output_dir, "best.pt")
if use_ema:
torch.save({
'model_state_dict': models['diffusion'].state_dict(),
'ema_state_dict': ema_unet.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_loss': best_loss,
'best_step': best_step,
'last_loss': accumulator / save_steps,
'last_step': global_step
}, best_save_path)
else:
torch.save({
'model_state_dict': models['diffusion'].state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_loss': best_loss,
'best_step': best_step,
'last_loss': accumulator / save_steps,
'last_step': global_step
}, best_save_path)
print(f"\nNew best model saved to {best_save_path} with loss {best_loss}")
# Save model and optimizer state
last_save_path = os.path.join(train_output_dir, f"last.pt")
if use_ema:
torch.save({
'model_state_dict': models['diffusion'].state_dict(),
'ema_state_dict': ema_unet.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_loss': best_loss,
'best_step': best_step,
'last_loss': accumulator / save_steps,
'last_step': global_step
}, last_save_path)
else:
torch.save({
'model_state_dict': models['diffusion'].state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_loss': best_loss,
'best_step': best_step,
'last_loss': accumulator / save_steps,
'last_step': global_step
}, last_save_path)
print(f"Saved state to {last_save_path}")
# Generate samples from the model
for i, prompt in enumerate(prompts):
# Sample images from the model
output_image = pipeline.generate(
prompt=prompt,
uncond_prompt=uncond_prompt,
input_image=None,
strength=0.9,
do_cfg=do_cfg,
cfg_scale=cfg_scale,
sampler_name=sampler,
n_inference_steps=num_inference_steps,
seed=seed,
models=models,
device=DEVICE,
idle_device=DEVICE,
tokenizer=tokenizer,
)
# Save the generated image
output_image = Image.fromarray(output_image)
# Display the generated image
display(output_image)
print(f"\nSaved images for step {global_step}")
print('Epoch: %d Step: %d Loss: %.5f Best Loss: %.5f Best Step: %d\n' % (epoch+1, global_step, accumulator / save_steps, best_loss, best_step))
accumulator = 0.0
global_step += 1
print(f"Average loss over epoch: {train_loss / (step + 1)}")
In [ ]:
s = '==> Training starts..'
s += f'\n\nModel file: {model_file}'
s += f'\nUNet file: {train_unet_file}'
s += f'\nBatch size: {BATCH_SIZE}'
s += f'\nWidth: {WIDTH}'
s += f'\nHeight: {HEIGHT}'
s += f'\nLatents width: {LATENTS_WIDTH}'
s += f'\nLatents height: {LATENTS_HEIGHT}'
s += f'\nFirst epoch: {last_step // len(train_dataloader)}'
s += f'\nNumber of training epochs: {num_train_epochs}'
s += f'\nLambda: {Lambda}'
s += f'\nLearning rate: {learning_rate}'
s += f'\nDiscriminative learning rate: {discriminative_learning_rate}'
s += f'\nAdam beta1: {adam_beta1}'
s += f'\nAdam beta2: {adam_beta2}'
s += f'\nAdam weight decay: {adam_weight_decay}'
s += f'\nAdam epsilon: {adam_epsilon}'
s += f'\nUse EMA: {use_ema}'
s += f'\nEMA decay: {ema_decay}'
s += f'\nWarmup steps: {warmup_steps}'
s += f'\nOutput directory: {train_output_dir}'
s += f'\nSave steps: {save_steps}'
s += f'\nDevice: {DEVICE}'
s += f'\nSampler: {sampler}'
s += f'\nNumber of inference steps: {num_inference_steps}'
s += f'\nSeed: {seed}'
for i, prompt in enumerate(prompts):
s += f'\nPrompt {i + 1}: {prompt}'
s += f'\nUnconditional prompt: {uncond_prompt}'
s += f'\nDo CFG: {do_cfg}'
s += f'\nCFG scale: {cfg_scale}'
s += f'\n\n'
print(s)
# Create the output directory
os.makedirs(train_output_dir, exist_ok=True)
train(num_train_epochs=num_train_epochs, device=DEVICE, save_steps=save_steps)
==> Training starts.. Model file: data/v1-5-pruned.ckpt UNet file: None Batch size: 1 Width: 512 Height: 512 Latents width: 64 Latents height: 64 First epoch: 0 Number of training epochs: 6 Lambda: 1.0 Learning rate: 1e-05 Discriminative learning rate: 0.0001 Adam beta1: 0.9 Adam beta2: 0.999 Adam weight decay: 0.0001 Adam epsilon: 1e-08 Use EMA: False EMA decay: 0.9999 Warmup steps: 1000 Output directory: ../results/output_1 Save steps: 5000 Device: cuda Sampler: ddpm Number of inference steps: 50 Seed: 42 Prompt 1: A river with boats docked and houses in the background Prompt 2: A piece of chocolate swirled cake on a plate Prompt 3: A large bed sitting next to a small Christmas Tree surrounded by pictures Prompt 4: A bear searching for food near the river Unconditional prompt: Do CFG: True CFG scale: 3
Epoch: 0%| | 0/6 [00:00<?, ?it/s]/home/furkan/miniconda3/envs/DiffDis/lib/python3.12/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1712609048481/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:84.) return F.conv2d(input, weight, bias, self.stride, /home/furkan/CENG796/code/pipeline.py:188: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). timesteps = torch.tensor(timesteps, dtype=torch.float32)[:, None] # convert the batch of timesteps to a 2-D tensor
Step: 5000/5046 Loss: 0.11019457876682281 Time: 0.1516003608703613388 New best model saved to ../results/output_1/best.pt with loss 0.2574170602272265 Saved state to ../results/output_1/last.pt
100%|██████████| 50/50 [00:04<00:00, 11.53it/s]
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
Saved images for step 5000 Epoch: 1 Step: 5000 Loss: 0.25742 Best Loss: 0.25742 Best Step: 5000 Step: 5044/5046 Loss: 0.9965981841087341 Time: 0.19762754440307617535
Epoch: 17%|█▋ | 1/6 [25:27<2:07:15, 1527.18s/it]
Average loss over epoch: 0.25731121857589336 Time: 0.1969141960144043 Step: 4954/5046 Loss: 0.17604000866413116 Time: 0.1969144344329834474 New best model saved to ../results/output_1/best.pt with loss 0.2406541142154485 Saved state to ../results/output_1/last.pt
100%|██████████| 50/50 [00:04<00:00, 11.83it/s]
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
100%|██████████| 50/50 [00:04<00:00, 11.76it/s]
Saved images for step 10000 Epoch: 2 Step: 10000 Loss: 0.24065 Best Loss: 0.24065 Best Step: 10000 Step: 5044/5046 Loss: 0.08904671669006348 Time: 0.1974637508392334577
Epoch: 33%|███▎ | 2/6 [49:31<1:38:33, 1478.35s/it]
Average loss over epoch: 0.24081591752830642 Time: 0.19686222076416016 Step: 4908/5046 Loss: 0.0033800562378019094 Time: 0.19713997840881348 New best model saved to ../results/output_1/best.pt with loss 0.23109131355108692 Saved state to ../results/output_1/last.pt
100%|██████████| 50/50 [00:04<00:00, 11.83it/s]
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
Saved images for step 15000 Epoch: 3 Step: 15000 Loss: 0.23109 Best Loss: 0.23109 Best Step: 15000 Step: 5044/5046 Loss: 0.13101424276828766 Time: 0.196987390518188483
Epoch: 50%|█████ | 3/6 [1:13:07<1:12:29, 1449.88s/it]
Average loss over epoch: 0.2313138278532704 Time: 0.19705843925476074 Saved state to ../results/output_1/last.pt Time: 0.197257280349731459523
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
Saved images for step 20000 Epoch: 4 Step: 20000 Loss: 0.23762 Best Loss: 0.23109 Best Step: 15000 Step: 5044/5046 Loss: 0.03353509679436684 Time: 0.2095785140991211436
Epoch: 67%|██████▋ | 4/6 [1:36:32<47:44, 1432.32s/it]
Average loss over epoch: 0.23832769447767252 Time: 0.20766687393188477 Saved state to ../results/output_1/last.pt9 Time: 0.1973145008087158282
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
Saved images for step 25000 Epoch: 5 Step: 25000 Loss: 0.23724 Best Loss: 0.23109 Best Step: 15000 Step: 5044/5046 Loss: 0.17421704530715942 Time: 0.1988165378570556664
Epoch: 83%|████████▎ | 5/6 [1:59:59<23:43, 1423.16s/it]
Average loss over epoch: 0.23611877438210113 Time: 0.19716739654541016 Saved state to ../results/output_1/last.pt8 Time: 0.1971397399902343833
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
100%|██████████| 50/50 [00:04<00:00, 11.80it/s]
100%|██████████| 50/50 [00:04<00:00, 11.80it/s]
Saved images for step 30000 Epoch: 6 Step: 30000 Loss: 0.23115 Best Loss: 0.23109 Best Step: 15000 Step: 5044/5046 Loss: 1.00068199634552 Time: 0.1973047256469726630573
Epoch: 100%|██████████| 6/6 [2:23:25<00:00, 1434.26s/it]
Average loss over epoch: 0.23221419323251438 Time: 0.19747495651245117
In [ ]:
import torch
import torch.nn.functional as F
import os
import model_loader
import time
import pipeline
from PIL import Image
from transformers import CLIPTokenizer
from IPython.display import display
# Set the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load the models
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
if inference_unet_file is not None:
# Load the UNet model
print(f"Loading UNet model from {inference_unet_file}")
models['diffusion'].load_state_dict(torch.load(inference_unet_file)['model_state_dict'])
# TEXT TO IMAGE
tokenizer = CLIPTokenizer("./data/vocab.json", merges_file="./data/merges.txt")
# Generate samples from the model
for i, prompt in enumerate(prompts):
for j in range(num_samples):
start = time.time()
# Sample images from the model
output_image = pipeline.generate(
prompt=prompt,
uncond_prompt=uncond_prompt,
input_image=None,
strength=0.9,
do_cfg=do_cfg,
cfg_scale=cfg_scale,
sampler_name=sampler,
n_inference_steps=num_inference_steps,
seed=seed,
models=models,
device=DEVICE,
idle_device=DEVICE,
tokenizer=tokenizer,
)
end = time.time()
print(f"PROMPT {i+1} - SAMPLE {j+1} - TIME: {end - start:.2f}s\n")
# Save the generated image
output_image = Image.fromarray(output_image)
# Display the generated image
display(output_image)
/home/furkan/miniconda3/envs/DiffDis/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
Loading UNet model from ../results/output_1/last.pt
0%| | 0/50 [00:00<?, ?it/s]/home/furkan/CENG796/code/pipeline.py:188: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). timesteps = torch.tensor(timesteps, dtype=torch.float32)[:, None] # convert the batch of timesteps to a 2-D tensor /home/furkan/miniconda3/envs/DiffDis/lib/python3.12/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1712609048481/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:84.) return F.conv2d(input, weight, bias, self.stride, 100%|██████████| 50/50 [00:07<00:00, 6.58it/s]
PROMPT 1 - SAMPLE 1 - TIME: 7.85s
100%|██████████| 50/50 [00:07<00:00, 6.61it/s]
PROMPT 2 - SAMPLE 1 - TIME: 7.77s
100%|██████████| 50/50 [00:07<00:00, 6.66it/s]
PROMPT 3 - SAMPLE 1 - TIME: 7.71s
100%|██████████| 50/50 [00:07<00:00, 6.60it/s]
PROMPT 4 - SAMPLE 1 - TIME: 7.78s
In [ ]:
import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import os
from ddpm import DDPMSampler
from pipeline import get_time_embedding
import model_loader
import time
from transformers import CLIPTokenizer
# Set the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load the models
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
ddpm = DDPMSampler(generator=None)
if test_unet_file is not None:
# Load the UNet model
print(f"Loading UNet model from {test_unet_file}")
if use_ema:
models['diffusion'].load_state_dict(torch.load(test_unet_file)['ema_state_dict'])
else:
models['diffusion'].load_state_dict(torch.load(test_unet_file)['model_state_dict'])
# TEXT TO IMAGE
tokenizer = CLIPTokenizer("./data/vocab.json", merges_file="./data/merges.txt")
# Set the models['encoder'], models['clip'], models['diffusion'] to eval mode
models['encoder'].eval()
models['clip'].eval()
models['diffusion'].eval()
print("==> Testing starts..")
Loading UNet model from ../results/output_1/last.pt ==> Testing starts..
In [ ]:
def test(device="cuda"):
# Get the transform for the test data
transform = transforms.Compose([
transforms.Resize((WIDTH, HEIGHT), interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
# Load the CIFAR-10 dataset
if test_dataset == "CIFAR10":
testset = torchvision.datasets.CIFAR10(
root='../dataset', train=False, download=True, transform=transform)
elif test_dataset == "CIFAR100":
testset = torchvision.datasets.CIFAR100(
root='../dataset', train=False, download=True, transform=transform)
print(f"Test dataset: {test_dataset} | Number of test samples: {len(testset)}")
# Load the test data
testloader = torch.utils.data.DataLoader(
testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
# Move models to the device
models['encoder'].to(device)
models['clip'].to(device)
models['diffusion'].to(device)
# Define the class names and tokens
class_names = testset.classes
class_tokens = []
# Tokenize class names
for class_name in class_names:
# Tokenize text
tokens = tokenizer.batch_encode_plus(
[class_name], padding="max_length", max_length=77
).input_ids
tokens = torch.tensor(tokens, dtype=torch.long).squeeze()
class_tokens.append(tokens)
# Convert list of class tokens to a tensor
class_tokens = torch.stack(class_tokens).to(device)
print(f"Class tokens shape: {class_tokens.shape}")
# Encode class tokens with the CLIP model
with torch.no_grad():
# Encode class tokens
encoder_hidden_states = models['clip'](class_tokens)
# Average and normalize class embeddings
class_embeddings = encoder_hidden_states.mean(dim=1)
class_embeddings = F.normalize(class_embeddings, p=2, dim=-1)
print(f"Class embeddings shape: {class_embeddings.shape}\n")
# Start testing
test_loss = 0.0
num_test_steps = len(testloader)
correct_predictions = 0
total_predictions = 0
with torch.no_grad():
for batch_idx, (images, targets) in enumerate(testloader):
start_time = time.time()
# Move batch to the device
images = images.to(device)
targets = targets.to(device)
texts = [class_tokens[target] for target in targets]
# Convert list of class tokens to a tensor
texts = torch.stack(texts).to(device)
# Encode images to latent space
encoder_noise = torch.randn(images.shape[0], 4, LATENTS_HEIGHT, LATENTS_WIDTH).to(device) # Shape (BATCH_SIZE, 4, 32, 32)
latents = models['encoder'](images, encoder_noise)
# Sample noise and timesteps for diffusion process
bsz = latents.shape[0]
timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()
text_timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()
# Add noise to latents and texts
noisy_latents, image_noise = ddpm.add_noise(latents, timesteps)
encoder_hidden_states = models['clip'](texts)
noisy_text_query, text_noise = ddpm.add_noise(encoder_hidden_states, text_timesteps)
# Get time embeddings
image_time_embeddings = get_time_embedding(timesteps, is_image=True).to(device)
text_time_embeddings = get_time_embedding(timesteps, is_image=False).to(device)
# Average and normalize text time embeddings
average_noisy_text_query = noisy_text_query.mean(dim=1)
text_query = F.normalize(average_noisy_text_query, p=2, dim=-1)
# Randomly drop 10% of text and image conditions: Context Free Guidance
if torch.rand(1).item() < 0.1:
text_query = torch.zeros_like(text_query)
if torch.rand(1).item() < 0.1:
noisy_latents = torch.zeros_like(noisy_latents)
# Predict the noise residual and compute loss
_, text_pred = models['diffusion'](noisy_latents, encoder_hidden_states, image_time_embeddings, text_time_embeddings, text_query)
# Calculate loss
loss = F.mse_loss(text_pred.float(), text_query.float(), reduction="mean")
test_loss += loss.item()
# Calculate cosine similarity between the generated text query and class embeddings
similarities = F.cosine_similarity(text_pred.unsqueeze(1), class_embeddings.unsqueeze(0), dim=-1)
predicted_classes = similarities.argmax(dim=-1)
# Compare predictions with actual targets
correct_predictions += (predicted_classes == targets).sum().item()
total_predictions += targets.size(0)
end_time = time.time()
print(f"Batch {batch_idx + 1}/{num_test_steps} | Loss: {loss:.4f} | Time: {end_time - start_time:.2f}s", end="\r")
# Calculate total accuracy
accuracy = correct_predictions / total_predictions
s = f"Accuracy: %.2f%% ({correct_predictions}/{total_predictions})" % (accuracy * 100)
s += f"\nTest Loss: {test_loss / num_test_steps:.4f}"
print("\n" + s)
In [ ]:
s = '==> Testing starts..'
s += f'\n\nTest dataset: {test_dataset}'
s += f'\nModel file: {model_file}'
s += f'\nUNet file: {test_unet_file}'
s += f'\nBatch size: {BATCH_SIZE}'
s += f'\nWidth: {WIDTH}'
s += f'\nHeight: {HEIGHT}'
s += f'\nLatents width: {LATENTS_WIDTH}'
s += f'\nLatents height: {LATENTS_HEIGHT}'
s += f'\nNumber of training epochs: {num_train_epochs}'
s += f'\nLambda: {Lambda}'
s += f'\nLearning rate: {learning_rate}'
s += f'\nDiscriminative learning rate: {discriminative_learning_rate}'
s += f'\nAdam beta1: {adam_beta1}'
s += f'\nAdam beta2: {adam_beta2}'
s += f'\nAdam weight decay: {adam_weight_decay}'
s += f'\nAdam epsilon: {adam_epsilon}'
s += f'\nUse EMA: {use_ema}'
s += f'\nEMA decay: {ema_decay}'
s += f'\nWarmup steps: {warmup_steps}'
s += f'\nOutput directory: {test_output_dir}'
s += f'\nSave steps: {save_steps}'
s += f'\nDevice: {DEVICE}'
s += f'\nSampler: {sampler}'
s += f'\nNumber of inference steps: {num_inference_steps}'
s += f'\nSeed: {seed}'
for i, prompt in enumerate(prompts):
s += f'\nPrompt {i + 1}: {prompt}'
s += f'\nUnconditional prompt: {uncond_prompt}'
s += f'\nDo CFG: {do_cfg}'
s += f'\nCFG scale: {cfg_scale}'
s += f'\n\n'
print(s)
# Test the model on the CIFAR-10 dataset
test(device=DEVICE)
==> Testing starts.. Test dataset: CIFAR10 Model file: data/v1-5-pruned.ckpt UNet file: ../results/output_1/last.pt Batch size: 1 Width: 512 Height: 512 Latents width: 64 Latents height: 64 Number of training epochs: 6 Lambda: 1.0 Learning rate: 1e-05 Discriminative learning rate: 0.0001 Adam beta1: 0.9 Adam beta2: 0.999 Adam weight decay: 0.0001 Adam epsilon: 1e-08 Use EMA: False EMA decay: 0.9999 Warmup steps: 1000 Output directory: ../results/CIFAR10 Save steps: 5000 Device: cuda Sampler: ddpm Number of inference steps: 50 Seed: 42 Prompt 1: A river with boats docked and houses in the background Prompt 2: A piece of chocolate swirled cake on a plate Prompt 3: A large bed sitting next to a small Christmas Tree surrounded by pictures Prompt 4: A bear searching for food near the river Unconditional prompt: Do CFG: True CFG scale: 3 Files already downloaded and verified Test dataset: CIFAR10 | Number of test samples: 10000 Class tokens shape: torch.Size([10, 77]) Class embeddings shape: torch.Size([10, 768]) Batch 10000/10000 | Loss: 0.0011 | Time: 0.08s Accuracy: 99.93% (9993/10000) Test Loss: 0.0007
In [ ]:
# Test the model on the CIFAR-100 dataset
test_dataset = "CIFAR100"
test_output_dir = "../results/" + test_dataset
test(device=DEVICE)
Files already downloaded and verified Test dataset: CIFAR100 | Number of test samples: 10000 Class tokens shape: torch.Size([100, 77]) Class embeddings shape: torch.Size([100, 768]) Batch 10000/10000 | Loss: 0.0008 | Time: 0.08s Accuracy: 92.74% (9274/10000) Test Loss: 0.0008